from torch.utils.data import Dataset


class DynamicDataset(Dataset):
    """Map-style dataset that zips several index-aligned sequences together.

    Each positional argument is a sequence (list, tensor, array, ...) that
    supports ``len()`` and indexing. Item ``i`` of the dataset is the tuple
    ``(args[0][i], args[1][i], ...)``.

    Raises:
        ValueError: if the given sequences do not all have the same length.
    """

    def __init__(self, *args):
        # ``super(DynamicDataset)`` (no second argument) is an *unbound* super
        # object, so calling ``__init__`` on it never initialises ``Dataset``.
        super().__init__()
        lengths = [len(arg) for arg in args]
        # Fail fast on misaligned inputs instead of raising IndexError midway
        # through an epoch (mirrors the size check in ``TensorDataset``).
        if len(set(lengths)) > 1:
            raise ValueError(
                f"All sequences must have the same length, got lengths {lengths}"
            )
        self.args = args

    def __len__(self):
        """Return the number of items (0 when no sequences were given)."""
        return len(self.args[0]) if self.args else 0

    def __getitem__(self, index):
        """Return a tuple holding element ``index`` of every sequence."""
        return tuple(arg[index] for arg in self.args)
